{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在整个机器学习过程中，除了训练模型外，应该就属数据预处理过程消耗的精力最多，数据预处理过程需要完成的任务包括数据读取、过滤、转换等等。为了将用户从繁杂的预处理操作中解放处理，更多地将精力放在算法建模上，TensorFlow中提供了data模块，这一模块以多种方式提供了数据读取、数据处理、数据保存等功能。本文重点是data模块中的Dataset对象。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1 创建"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对于创建Dataset对象，官方文档中总结为两种方式，我将这两种方式细化后总结为4中方式：  \n",
    "\n",
    "**（1）通过Dataset中的range()方法创建包含一定序列的Dataset对象。**\n",
    "- **[range()](https://tensorflow.google.cn/api_docs/python/tf/data/Dataset#range)**\n",
    "\n",
    "range()方法是Dataset内部定义的一个的静态方法，可以直接通过类名调用。另外，Dataset中的range()方法与Python本身内置的range()方法接受参数形式是一致的，可以接受range(begin)、range(begin, end)、range（begin, end, step）等多种方式传参。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensorflow.python.data.ops.dataset_ops.RangeDataset"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset1 = tf.data.Dataset.range(5)\n",
    "type(dataset1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注：RangeDataset是Dataset的一个子类。\n",
    "Dataset对象属于可迭代对象， 可通过循环进行遍历："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(0, shape=(), dtype=int64)\n",
      "0\n",
      "tf.Tensor(1, shape=(), dtype=int64)\n",
      "1\n",
      "tf.Tensor(2, shape=(), dtype=int64)\n",
      "2\n",
      "tf.Tensor(3, shape=(), dtype=int64)\n",
      "3\n",
      "tf.Tensor(4, shape=(), dtype=int64)\n",
      "4\n"
     ]
    }
   ],
   "source": [
    "for i in dataset1:\n",
    "    print(i)\n",
    "    print(i.numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看到，range()方法创建的Dataset对象内部每一个元素都以Tensor对象的形式存在，可以通过numpy()方法访问真实值。 \n",
    "- **[from_generator()](https://tensorflow.google.cn/guide/data#consuming_python_generators)**\n",
    "\n",
    "如果你觉得range()方法不够灵活，功能不够强大，那么你可以尝试使用from_generator()方法。from_generator()方法接收一个可调用的生成器函数最为参数，在遍历from_generator()方法返回的Dataset对象过程中不断生成新的数据，减少内存占用，这在大数据集中很有用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def count(stop):\n",
    "  i = 0\n",
    "  while i<stop:\n",
    "    print('第%s次调用……'%i)\n",
    "    yield i\n",
    "    i += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset2 = tf.data.Dataset.from_generator(count, args=[3], output_types=tf.int32, output_shapes = (), )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = iter(dataset2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "第0次调用……\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=46, shape=(), dtype=int32, numpy=0>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "第1次调用……\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=47, shape=(), dtype=int32, numpy=1>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "第0次调用……\n",
      "tf.Tensor(0, shape=(), dtype=int32)\n",
      "0\n",
      "第1次调用……\n",
      "tf.Tensor(1, shape=(), dtype=int32)\n",
      "1\n",
      "第2次调用……\n",
      "tf.Tensor(2, shape=(), dtype=int32)\n",
      "2\n"
     ]
    }
   ],
   "source": [
    "for i in dataset2:\n",
    "    print(i)\n",
    "    print(i.numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**（2）通过接收其他类型的集合类对象创建Dataset对象。**这里所说的集合类型对象包含Python内置的list、tuple，numpy中的ndarray等等。这种创建Dataset对象的方法大多通过from_tensors()和from_tensor_slices()两个方法实现。这两个方法很常用，重点说一说。\n",
    "- **[from_tensors()](https://tensorflow.google.cn/api_docs/python/tf/data/Dataset#from_tensors)**  \n",
    "from_tensors()方法接受一个集合类型对象作为参数，返回值为一个TensorDataset类型对象，对象内容、shape因传入参数类型而异。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "当接收参数为list或Tensor对象时，返回的情况是一样的，因为TensorFlow内部会将list先转为Tensor对象，然后实例化一个Dataset对象："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = [0,1,2,3,4]\n",
    "dataset1 = tf.data.Dataset.from_tensors(a)\n",
    "dataset1_n = tf.data.Dataset.from_tensors(np.array(a))\n",
    "dataset1_t = tf.data.Dataset.from_tensors(tf.constant(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<TensorDataset shapes: (5,), types: tf.int32>,\n",
       " <tf.Tensor: id=67, shape=(5,), dtype=int32, numpy=array([0, 1, 2, 3, 4], dtype=int32)>)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset1,next(iter(dataset1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<TensorDataset shapes: (5,), types: tf.int64>,\n",
       " <tf.Tensor: id=73, shape=(5,), dtype=int64, numpy=array([0, 1, 2, 3, 4])>)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset1_n,next(iter(dataset1_n))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<TensorDataset shapes: (5,), types: tf.int32>,\n",
       " <tf.Tensor: id=79, shape=(5,), dtype=int32, numpy=array([0, 1, 2, 3, 4], dtype=int32)>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset1_t,next(iter(dataset1_t))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "多维结构也是一样的："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = [0,1,2,3,4]\n",
    "b = [5,6,7,8,9]\n",
    "dataset2 = tf.data.Dataset.from_tensors([a,b])\n",
    "dataset2_n = tf.data.Dataset.from_tensors(np.array([a,b]))\n",
    "dataset2_t = tf.data.Dataset.from_tensors(tf.constant([a,b]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<TensorDataset shapes: (2, 5), types: tf.int32>,\n",
       " <tf.Tensor: id=91, shape=(2, 5), dtype=int32, numpy=\n",
       " array([[0, 1, 2, 3, 4],\n",
       "        [5, 6, 7, 8, 9]], dtype=int32)>)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset2,next(iter(dataset2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<TensorDataset shapes: (2, 5), types: tf.int64>,\n",
       " <tf.Tensor: id=97, shape=(2, 5), dtype=int64, numpy=\n",
       " array([[0, 1, 2, 3, 4],\n",
       "        [5, 6, 7, 8, 9]])>)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset2_n,next(iter(dataset2_n))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<TensorDataset shapes: (2, 5), types: tf.int32>,\n",
       " <tf.Tensor: id=103, shape=(2, 5), dtype=int32, numpy=\n",
       " array([[0, 1, 2, 3, 4],\n",
       "        [5, 6, 7, 8, 9]], dtype=int32)>)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset2_t,next(iter(dataset2_t))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "当接收参数为数组就不一样了，此时Dataset内部内容为一个tuple，tuple的元素是原来tuple元素转换为的Tensor对象："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = [0,1,2,3,4]\n",
    "b = [5,6,7,8,9]\n",
    "dataset3 = tf.data.Dataset.from_tensors((a,b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'tuple'>\n",
      "(<tf.Tensor: id=112, shape=(5,), dtype=int32, numpy=array([0, 1, 2, 3, 4], dtype=int32)>, <tf.Tensor: id=113, shape=(5,), dtype=int32, numpy=array([5, 6, 7, 8, 9], dtype=int32)>)\n",
      "tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)\n",
      "tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "for i in dataset3:\n",
    "    print(type(i))\n",
    "    print(i)\n",
    "    for j in i:\n",
    "        print(j)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- **[from_tensor_slices()](https://tensorflow.google.cn/api_docs/python/tf/data/Dataset#from_tensor_slices)**  \n",
    "from_tensor_slices()方法返回一个TensorSliceDataset类对象，TensorSliceDataset对象比from_tensors()方法返回的TensorDataset对象支持更加丰富的操作，例如batch操作等，因此在实际应用中更加广泛。返回的TensorSliceDataset对象内容、shape因传入参数类型而异。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "当传入一个list时，时将list中元素逐个转换为Tensor对象然后依次放入Dataset中，所以Dataset中有多个Tensor对象："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = [0,1,2,3,4]\n",
    "dataset1 = tf.data.Dataset.from_tensor_slices(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<TensorSliceDataset shapes: (), types: tf.int32>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 --> tf.Tensor(0, shape=(), dtype=int32)\n",
      "1 --> tf.Tensor(1, shape=(), dtype=int32)\n",
      "2 --> tf.Tensor(2, shape=(), dtype=int32)\n",
      "3 --> tf.Tensor(3, shape=(), dtype=int32)\n",
      "4 --> tf.Tensor(4, shape=(), dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "for i,elem in enumerate(dataset1):\n",
    "    print(i, '-->', elem)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = [0,1,2,3,4]\n",
    "b = [5,6,7,8,9]\n",
    "dataset2 = tf.data.Dataset.from_tensor_slices([a,b])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<TensorSliceDataset shapes: (5,), types: tf.int32>"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 --> tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)\n",
      "1 --> tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "for i,elem in enumerate(dataset2):\n",
    "    print(i, '-->', elem)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "当传入参数为tuple时，会将tuple中各元素转换为Tensor对象，然后将第一维度对应位置的切片进行重新组合成一个tuple依次放入到Dataset中，所以在返回的Dataset中有多个tuple。这种形式在对训练集和测试集进行重新组合是非常实用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = [0,1,2,3,4]\n",
    "b = [5,6,7,8,9]\n",
    "dataset1 = tf.data.Dataset.from_tensor_slices((a,b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<TensorSliceDataset shapes: ((), ()), types: (tf.int32, tf.int32)>"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(<tf.Tensor: id=143, shape=(), dtype=int32, numpy=0>, <tf.Tensor: id=144, shape=(), dtype=int32, numpy=5>)\n",
      "(<tf.Tensor: id=145, shape=(), dtype=int32, numpy=1>, <tf.Tensor: id=146, shape=(), dtype=int32, numpy=6>)\n",
      "(<tf.Tensor: id=147, shape=(), dtype=int32, numpy=2>, <tf.Tensor: id=148, shape=(), dtype=int32, numpy=7>)\n",
      "(<tf.Tensor: id=149, shape=(), dtype=int32, numpy=3>, <tf.Tensor: id=150, shape=(), dtype=int32, numpy=8>)\n",
      "(<tf.Tensor: id=151, shape=(), dtype=int32, numpy=4>, <tf.Tensor: id=152, shape=(), dtype=int32, numpy=9>)\n"
     ]
    }
   ],
   "source": [
    "for i in dataset1:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "c = ['a','b','c','d','e']\n",
    "dataset3 = tf.data.Dataset.from_tensor_slices((a,b,c))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<TensorSliceDataset shapes: ((), (), ()), types: (tf.int32, tf.int32, tf.string)>"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(<tf.Tensor: id=162, shape=(), dtype=int32, numpy=0>, <tf.Tensor: id=163, shape=(), dtype=int32, numpy=5>, <tf.Tensor: id=164, shape=(), dtype=string, numpy=b'a'>)\n",
      "(<tf.Tensor: id=165, shape=(), dtype=int32, numpy=1>, <tf.Tensor: id=166, shape=(), dtype=int32, numpy=6>, <tf.Tensor: id=167, shape=(), dtype=string, numpy=b'b'>)\n",
      "(<tf.Tensor: id=168, shape=(), dtype=int32, numpy=2>, <tf.Tensor: id=169, shape=(), dtype=int32, numpy=7>, <tf.Tensor: id=170, shape=(), dtype=string, numpy=b'c'>)\n",
      "(<tf.Tensor: id=171, shape=(), dtype=int32, numpy=3>, <tf.Tensor: id=172, shape=(), dtype=int32, numpy=8>, <tf.Tensor: id=173, shape=(), dtype=string, numpy=b'd'>)\n",
      "(<tf.Tensor: id=174, shape=(), dtype=int32, numpy=4>, <tf.Tensor: id=175, shape=(), dtype=int32, numpy=9>, <tf.Tensor: id=176, shape=(), dtype=string, numpy=b'e'>)\n"
     ]
    }
   ],
   "source": [
    "for i in dataset3:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对比总结一下from_generator(）、from_tensor()、from_tensor_slices()这三个方法：\n",
    "- from_tensors()在形式上与from_tensor_slices()很相似，但其实from_tensors()方法出场频率上比from_tensor_slices()差太多，因为from_tensor_slices()的功能更加符合实际需求，且返回的TensorSliceDataset对象也提供更多的数据处理功能。from_tensors()方法在接受list类型参数时，将整个list转换为Tensor对象放入Dataset中，当接受参数为tuple时，将tuple内元素转换为Tensor对象，然后将这个tuple放入Dataset中。\n",
    "- from_generator(）方法接受一个可调用的生成器函数作为参数，在遍历Dataset对象时，通过通用生成器函数继续生成新的数据供训练和测试模型使用，这在大数据集合中很实用。\n",
    "- from_tensor_slices()方法接受参数为list时，将list各元素依次转换为Tensor对象，然后依次放入Dataset中；更为常见的情况是接受的参数为tuple，在这种情况下，要求tuple中各元素第一维度长度必须相等，from_tensor_slices()方法会将tuple各元素第一维度进行拆解，然后将对应位置的元素进行重组成一个个tuple依次放入Dataset中，这一功能在重新组合数据集属性和标签时很有用。另外，from_tensor_slices()方法返回的TensorSliceDataset对象支持batch、shuffle等等功能对数据进一步处理。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**（3）通过读取磁盘中的文件（文本、图片等等）来创建Dataset。**tf.data中提供了TextLineDataset、TFRecordDataset等对象来实现此功能。这部分内容比较多，也比较重要，我打算后续用专门一篇博客来总结这部分内容。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2 功能函数"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**（1）take()**  \n",
    "\n",
    "功能：用于返回一个新的Dataset对象，新的Dataset对象包含的数据是原Dataset对象的子集。 \n",
    "\n",
    "参数：  \n",
    "- count：整型，用于指定前count条数据用于创建新的Dataset对象，如果count为-1或大于原Dataset对象的size,则用原Dataset对象的全部数据创建新的对象。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = tf.data.Dataset.range(10)\n",
    "dataset_take = dataset.take(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(0, shape=(), dtype=int64)\n",
      "tf.Tensor(1, shape=(), dtype=int64)\n",
      "tf.Tensor(2, shape=(), dtype=int64)\n",
      "tf.Tensor(3, shape=(), dtype=int64)\n",
      "tf.Tensor(4, shape=(), dtype=int64)\n"
     ]
    }
   ],
   "source": [
    "for i in dataset_take:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**（2）batch()**  \n",
    "\n",
    "功能：将Dataset中连续的数据分割成批。 \n",
    "\n",
    "参数：  \n",
    "- batch_size：在单个批次中合并的此数据集的连续元素数。\n",
    "- drop_remainder：如果最后一批的数据量少于指定的batch_size，是否抛弃最后一批，默认为False，表示不抛弃。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = tf.data.Dataset.range(11)\n",
    "dataset_batch = dataset.batch(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor([0 1 2], shape=(3,), dtype=int64)\n",
      "tf.Tensor([3 4 5], shape=(3,), dtype=int64)\n",
      "tf.Tensor([6 7 8], shape=(3,), dtype=int64)\n",
      "tf.Tensor([ 9 10], shape=(2,), dtype=int64)\n"
     ]
    }
   ],
   "source": [
    "for i in dataset_batch:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_batch = dataset.batch(3,drop_remainder=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor([0 1 2], shape=(3,), dtype=int64)\n",
      "tf.Tensor([3 4 5], shape=(3,), dtype=int64)\n",
      "tf.Tensor([6 7 8], shape=(3,), dtype=int64)\n"
     ]
    }
   ],
   "source": [
    "for i in dataset_batch:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_x = tf.random.uniform((10,3),maxval=100, dtype=tf.int32)\n",
    "train_y = tf.range(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(<tf.Tensor: id=236, shape=(3,), dtype=int32, numpy=array([81, 53, 85], dtype=int32)>, <tf.Tensor: id=237, shape=(), dtype=int32, numpy=0>)\n",
      "(<tf.Tensor: id=238, shape=(3,), dtype=int32, numpy=array([13,  7, 25], dtype=int32)>, <tf.Tensor: id=239, shape=(), dtype=int32, numpy=1>)\n",
      "(<tf.Tensor: id=240, shape=(3,), dtype=int32, numpy=array([83, 25, 55], dtype=int32)>, <tf.Tensor: id=241, shape=(), dtype=int32, numpy=2>)\n"
     ]
    }
   ],
   "source": [
    "for i in dataset.take(3):\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_batch = dataset.batch(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(<tf.Tensor: id=250, shape=(4, 3), dtype=int32, numpy=\n",
      "array([[81, 53, 85],\n",
      "       [13,  7, 25],\n",
      "       [83, 25, 55],\n",
      "       [53, 41, 11]], dtype=int32)>, <tf.Tensor: id=251, shape=(4,), dtype=int32, numpy=array([0, 1, 2, 3], dtype=int32)>)\n",
      "(<tf.Tensor: id=252, shape=(4, 3), dtype=int32, numpy=\n",
      "array([[41, 58, 39],\n",
      "       [44, 68, 55],\n",
      "       [52, 34, 22],\n",
      "       [66, 39,  5]], dtype=int32)>, <tf.Tensor: id=253, shape=(4,), dtype=int32, numpy=array([4, 5, 6, 7], dtype=int32)>)\n",
      "(<tf.Tensor: id=254, shape=(2, 3), dtype=int32, numpy=\n",
      "array([[73,  8, 20],\n",
      "       [67, 71, 98]], dtype=int32)>, <tf.Tensor: id=255, shape=(2,), dtype=int32, numpy=array([8, 9], dtype=int32)>)\n"
     ]
    }
   ],
   "source": [
    "for i in dataset_batch:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为什么在训练模型时要将Dataset分割成一个个batch呢？\n",
    "- 对于小数据集是否使用batch关系不大，但是对于大数据集如果不分割成batch意味着将这个数据集一次性输入模型中，容易造成内存爆炸。\n",
    "- 通过并行化提高内存的利用率。就是尽量让你的GPU满载运行，提高训练速度。\n",
    "- 单个epoch的迭代次数减少了，参数的调整也慢了，假如要达到相同的识别精度，需要更多的epoch。\n",
    "- 适当Batch Size使得梯度下降方向更加准确。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**（3）padded_batch()**  \n",
    "\n",
    "功能： batch()的进阶版，可以对shape不一致的连续元素进行分批。  \n",
    "\n",
    "参数：  \n",
    "- batch_size：在单个批次中合并的此数据集的连续元素个数。\n",
    "- padded_shapes：tf.TensorShape或其他描述tf.int64矢量张量对象，表示在批处理之前每个输入元素的各个组件应填充到的形状。如果参数中有None，则表示将填充为每个批次中该尺寸的最大尺寸。\n",
    "- padding_values：要用于各个组件的填充值。默认值0用于数字类型，字符串类型则默认为空字符。\n",
    "- drop_remainder：如果最后一批的数据量少于指定的batch_size，是否抛弃最后一批，默认为False，表示不抛弃。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = tf.data.Dataset.range(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_padded = dataset.padded_batch(4, padded_shapes=(None,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0 0 0]\n",
      " [1 0 0]\n",
      " [2 2 0]\n",
      " [3 3 3]]\n",
      "---------------------\n",
      "[[4 4 4 4 0 0 0]\n",
      " [5 5 5 5 5 0 0]\n",
      " [6 6 6 6 6 6 0]\n",
      " [7 7 7 7 7 7 7]]\n",
      "---------------------\n",
      "[[8 8 8 8 8 8 8 8 0]\n",
      " [9 9 9 9 9 9 9 9 9]]\n",
      "---------------------\n"
     ]
    }
   ],
   "source": [
    "for batch in dataset_padded:\n",
    "    print(batch.numpy())\n",
    "    print('---------------------')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_padded = dataset.padded_batch(4, padded_shapes=(10,),padding_values=tf.constant(9,dtype=tf.int64))  # 修改填充形状和填充元素"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[9 9 9 9 9 9 9 9 9 9]\n",
      " [1 9 9 9 9 9 9 9 9 9]\n",
      " [2 2 9 9 9 9 9 9 9 9]\n",
      " [3 3 3 9 9 9 9 9 9 9]]\n",
      "---------------------\n",
      "[[4 4 4 4 9 9 9 9 9 9]\n",
      " [5 5 5 5 5 9 9 9 9 9]\n",
      " [6 6 6 6 6 6 9 9 9 9]\n",
      " [7 7 7 7 7 7 7 9 9 9]]\n",
      "---------------------\n",
      "[[8 8 8 8 8 8 8 8 9 9]\n",
      " [9 9 9 9 9 9 9 9 9 9]]\n",
      "---------------------\n"
     ]
    }
   ],
   "source": [
    "for batch in dataset_padded:\n",
    "    print(batch.numpy())\n",
    "    print('---------------------')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**（4）map()**  \n",
    "\n",
    "功能： 以dataset中每一位元素为参数执行pap_func()方法，这一功能在数据预处理中修改dataset中元素是很实用。\n",
    "\n",
    "参数：\n",
    "- map_func:回调方法。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "def change_dtype(t):  # 将类型修改为int32\n",
    "    return tf.cast(t,dtype=tf.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = tf.data.Dataset.range(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(0, shape=(), dtype=int64)\n",
      "tf.Tensor(1, shape=(), dtype=int64)\n",
      "tf.Tensor(2, shape=(), dtype=int64)\n"
     ]
    }
   ],
   "source": [
    "for i in dataset:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_map = dataset.map(change_dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(0, shape=(), dtype=int32)\n",
      "tf.Tensor(1, shape=(), dtype=int32)\n",
      "tf.Tensor(2, shape=(), dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "for i in dataset_map:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "map_func的参数必须对应dataset中的元素类型，例如，如果dataset中元素是tuple，map_func可以这么定义："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "def change_dtype_2(t1,t2):\n",
    "    return t1/10,tf.cast(t2,dtype=tf.int32)*(-1)  # 第一位元素除以10，第二为元素乘以-1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = tf.data.Dataset.from_tensor_slices((tf.range(3),tf.range(3)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_map = dataset.map(change_dtype_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(<tf.Tensor: id=347, shape=(), dtype=float64, numpy=0.0>, <tf.Tensor: id=348, shape=(), dtype=int32, numpy=0>)\n",
      "(<tf.Tensor: id=349, shape=(), dtype=float64, numpy=0.1>, <tf.Tensor: id=350, shape=(), dtype=int32, numpy=-1>)\n",
      "(<tf.Tensor: id=351, shape=(), dtype=float64, numpy=0.2>, <tf.Tensor: id=352, shape=(), dtype=int32, numpy=-2>)\n"
     ]
    }
   ],
   "source": [
    "for i in dataset_map:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**（5）filter()**  \n",
    "\n",
    "功能：对Dataset中每一个执行指定过滤方法进行过滤，返回过滤后的Dataset对象  \n",
    "\n",
    "参数：\n",
    "- predicate：过滤方法，返回值必须为True或False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = tf.data.Dataset.range(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_func(t):  # 过滤出值为偶数的元素\n",
    "    if t % 2 == 0:\n",
    "        return True\n",
    "    else:\n",
    "        return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_filter = dataset.filter(filter_func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(0, shape=(), dtype=int64)\n",
      "tf.Tensor(2, shape=(), dtype=int64)\n",
      "tf.Tensor(4, shape=(), dtype=int64)\n"
     ]
    }
   ],
   "source": [
    "for i in dataset_filter:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**（6）shuffle()**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "功能：随机打乱数据  \n",
    "\n",
    "参数：\n",
    "- buffer_size：缓冲区大小，姑且认为是混乱程度吧，当值为1时，完全不打乱，当值为整个Dataset元素总数时，完全打乱。\n",
    "- seed：将用于创建分布的随机种子。\n",
    "- reshuffle_each_iteration：如果为true，则表示每次迭代数据集时都应进行伪随机重排，默认为True。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = tf.data.Dataset.range(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_s = dataset.shuffle(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(0, shape=(), dtype=int64)\n",
      "tf.Tensor(1, shape=(), dtype=int64)\n",
      "tf.Tensor(2, shape=(), dtype=int64)\n",
      "tf.Tensor(3, shape=(), dtype=int64)\n",
      "tf.Tensor(4, shape=(), dtype=int64)\n"
     ]
    }
   ],
   "source": [
    "for i in dataset_s:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_s = dataset.shuffle(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(0, shape=(), dtype=int64)\n",
      "tf.Tensor(4, shape=(), dtype=int64)\n",
      "tf.Tensor(1, shape=(), dtype=int64)\n",
      "tf.Tensor(2, shape=(), dtype=int64)\n",
      "tf.Tensor(3, shape=(), dtype=int64)\n"
     ]
    }
   ],
   "source": [
    "for i in dataset_s:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**（7）repeat()**  \n",
    "\n",
    "功能：对Dataset中的数据进行重复，以创建新的Dataset\n",
    "\n",
    "参数：\n",
    "- count：重复次数，默认为None，表示不重复，当值为-1时，表示无限重复。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = tf.data.Dataset.range(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_repeat = dataset.repeat(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(0, shape=(), dtype=int64)\n",
      "tf.Tensor(1, shape=(), dtype=int64)\n",
      "tf.Tensor(2, shape=(), dtype=int64)\n",
      "tf.Tensor(0, shape=(), dtype=int64)\n",
      "tf.Tensor(1, shape=(), dtype=int64)\n",
      "tf.Tensor(2, shape=(), dtype=int64)\n",
      "tf.Tensor(0, shape=(), dtype=int64)\n",
      "tf.Tensor(1, shape=(), dtype=int64)\n",
      "tf.Tensor(2, shape=(), dtype=int64)\n"
     ]
    }
   ],
   "source": [
    "for i in dataset_repeat:\n",
    "    print(i)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "study_python",
   "language": "python",
   "name": "study_python"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
